import os
import random

import numpy as np
from PIL import Image
from tqdm import tqdm

#-------------------------------------------------------#
#   trainval_percent: fraction of all labelled samples put
#       into trainval; the remainder is written to test.txt.
#       Lower it to hold out a separate test set.
#   train_percent: fraction of trainval put into train; the
#       rest goes to val (0.9 gives a 9:1 train:val split).
#
#   With trainval_percent = 1 no separate test set is split
#   off; the validation split is used as the test set.
#-------------------------------------------------------#
trainval_percent    = 1
train_percent       = 0.9
#-------------------------------------------------------#
#   Root folder of the VOC-style dataset. Labels are read
#   from <VOCdevkit_path>/datasets/SegmentationClass and the
#   split lists are written to
#   <VOCdevkit_path>/datasets/ImageSets/Segmentation.
#
#   The default below is a machine-specific absolute path;
#   set the VOCDEVKIT_PATH environment variable to use a
#   different location without editing this file.
#-------------------------------------------------------#
VOCdevkit_path      = os.environ.get(
    'VOCDEVKIT_PATH',
    'D:/AI_PYCODE/deeplabv3-plus-pytorch-main/deeplabv3-plus-pytorch-main/VOCdevkit',
)

if __name__ == "__main__":
    random.seed(0)
    print("Generate txt in ImageSets.")
    
    # Using absolute paths based on the specified VOCdevkit_path
    segfilepath     = os.path.join(VOCdevkit_path, 'datasets', 'SegmentationClass')
    saveBasePath    = os.path.join(VOCdevkit_path, 'datasets', 'ImageSets', 'Segmentation')

    # Ensure the save paths exist
    if not os.path.exists(saveBasePath):
        os.makedirs(saveBasePath)
    
    # Checking if the segmentation path exists
    if not os.path.exists(segfilepath):
        raise FileNotFoundError(f"Segmentation path not found: {segfilepath}")

    temp_seg = os.listdir(segfilepath)
    total_seg = [seg for seg in temp_seg if seg.endswith(".png")]

    num     = len(total_seg)  
    list    = range(num)  
    tv      = int(num * trainval_percent)  
    tr      = int(tv * train_percent)  
    trainval= random.sample(list, tv)  
    train   = random.sample(trainval, tr)  
    
    print("train and val size", tv)
    print("train size", tr)
    with open(os.path.join(saveBasePath, 'trainval.txt'), 'w') as ftrainval, \
         open(os.path.join(saveBasePath, 'test.txt'), 'w') as ftest, \
         open(os.path.join(saveBasePath, 'train.txt'), 'w') as ftrain, \
         open(os.path.join(saveBasePath, 'val.txt'), 'w') as fval:
        
        for i in list:  
            name = total_seg[i][:-4] + '\n'  
            if i in trainval:  
                ftrainval.write(name)  
                if i in train:  
                    ftrain.write(name)  
                else:  
                    fval.write(name)  
            else:  
                ftest.write(name)  
    
    print("Generate txt in ImageSets done.")

    print("Check datasets format, this may take a while.")
    print("检查数据集格式是否符合要求，这可能需要一段时间。")
    classes_nums = np.zeros([256], dtype=int)
    for i in tqdm(list):
        name = total_seg[i]
        png_file_name = os.path.join(segfilepath, name)
        if not os.path.exists(png_file_name):
            raise ValueError(f"未检测到标签图片{png_file_name}，请查看具体路径下文件是否存在以及后缀是否为png。")
        
        png = np.array(Image.open(png_file_name), np.uint8)
        if len(np.shape(png)) > 2:
            print(f"标签图片{name}的shape为{str(np.shape(png))}，不属于灰度图或者八位彩图，请仔细检查数据集格式。")

        classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
            
    print("打印像素点的值与数量。")
    print('-' * 37)
    print("| %15s | %15s |" % ("Key", "Value"))
    print('-' * 37)
    for i in range(256):
        if classes_nums[i] > 0:
            print("| %15s | %15s |" % (str(i), str(classes_nums[i])))
            print('-' * 37)
    
    if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
        print("检测到标签中像素点的值仅包含0与255，数据格式有误。")
        print("二分类问题需要将标签修改为背景的像素点值为0，目标的像素点值为1。")
    elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
        print("检测到标签中仅仅包含背景像素点，数据格式有误，请仔细检查数据集格式。")

    print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
    print("如果格式有误，参考:")
    print("https://github.com/bubbliiiing/segmentation-format-fix")
